# import argparse
# from threading import Thread
# import torch
# import gradio as gr
# from PIL import ImageDraw, Image
# import re
# from torchvision.transforms.v2 import Resize
# import cv2
# import time
# import queue
# import logging
# from moviepy.editor import ImageSequenceClip
# import random
# from ultralytics import YOLO



# from huggingface_hub import login
# # login("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")
# yolo_model = YOLO("yolov8s-world.pt")
# # Set up logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)

# key_frame_list = []

# parser = argparse.ArgumentParser()
# parser.add_argument("--cpu", action="store_true")
# args = parser.parse_args()



# device = torch.device("cuda:0")
# dtype = torch.float32


# from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration, AutoProcessor

# # bitsand bytes not suitable for windows
# # bnb_config = BitsAndBytesConfig(
# #     load_in_4bit=True,
# #     bnb_4bit_quant_type="nf4",
# #     bnb_4bit_compute_dtype=torch.bfloat16,
# # )

# model_id = "google/paligemma-3b-mix-224"
# # model_id = r"D:\paligemma_3b_ft_science_qa_224"
# model = PaliGemmaForConditionalGeneration.from_pretrained(
#     model_id,
#     # quantization_config=bnb_config,
#     device_map=device,
# )
# processor = AutoProcessor.from_pretrained(model_id)

# def answer_question(img, prompt):
#     inputs = processor(prompt, img, return_tensors="pt").to(device)
#     output = model.generate(**inputs, max_new_tokens=100)
#     asw = processor.decode(output[0], skip_special_tokens=True)[len(prompt):]
#     bbox = extract_bbox(asw)
#     return asw, bbox



# def extract_floats(text):
#     pattern = r"\[\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*\]"
#     match = re.search(pattern, text)
#     if match:
#         return [float(num) for num in match.groups()]
#     return None
# logger.info(f"Initialize extract_floats finished!")

# def extract_bbox(text):
#     bbox = None
#     if extract_floats(text) is not None:
#         x1, y1, x2, y2 = extract_floats(text)
#         bbox = (x1, y1, x2, y2)
#     return bbox
# logger.info(f"Initialize extract_bbox finished!")

# def process_answer(img, bbox):
#     if bbox is not None:
#         x1, y1, x2, y2 = bbox
#         draw_image = Resize(768)(img)
#         width, height = draw_image.size
#         x1, x2 = int(x1 * width), int(x2 * width)
#         y1, y2 = int(y1 * height), int(y2 * height)
#         bbox = (x1, y1, x2, y2)
#         ImageDraw.Draw(draw_image).rectangle(bbox, outline="red", width=3)
#         return gr.update(visible=True, value=draw_image)
#     return gr.update(visible=False, value=None)


# with gr.Blocks() as demo:
#     gr.Markdown(
#         """
#         # 🌔 视觉-语言多模态安全检测
#         """
#     )

#     gr.Markdown(
#         """
#         ---
#         ## 处理图片
#         """
#     )

#     with gr.Row():
#         prompt = gr.Textbox(label="提示词输入", placeholder="在这里输入您的问题...", scale=4)
#         submit = gr.Button("提交")
#     with gr.Row():
#         img = gr.Image(type="pil", label="上传一个图片")
#         with gr.Column(): 
#             output = gr.Markdown(label="回复结果")
#             ann = gr.Image(visible=False, label="图像标注结果")

#     submit.click(answer_question, [img, prompt], [output, ann])
#     prompt.submit(answer_question, [img, prompt], [output, ann])

#     output.change(process_answer, [img, output], ann, show_progress=True)

# demo.queue().launch(debug=True)



import datetime
import functools
import glob
import json
import logging
import os
import time

import gradio as gr
import jax
import PIL.Image
import gradio_helpers
import models
import paligemma_parse

INTRO_TEXT = """🤲 PaliGemma demo\n\n
| [Paper](https://arxiv.org/abs/2407.07726)
| [GitHub](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) 
| [HF blog post](https://huggingface.co/blog/paligemma) 
| [Google blog post](https://developers.googleblog.com/en/gemma-family-and-toolkit-expansion-io-2024)
| [Vertex AI Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363) 
| [Demo](https://huggingface.co/spaces/google/paligemma) 
|\n\n
[PaliGemma](https://ai.google.dev/gemma/docs/paligemma) is an open vision-language model by Google, 
inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and 
built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343) 
vision model and the [Gemma](https://arxiv.org/abs/2403.08295) language model. PaliGemma is designed as a versatile 
model for transfer to a wide range of vision-language tasks such as image and short video caption, visual question 
answering, text reading, object detection and object segmentation.
\n\n
This space includes models fine-tuned on a mix of downstream tasks. 
See the [blog post](https://huggingface.co/blog/paligemma) and 
[README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) 
for detailed information how to use and fine-tune PaliGemma models.
\n\n
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
"""


make_image = lambda value, visible: gr.Image(
    value, label='Image', type='filepath', visible=visible)
make_annotated_image = functools.partial(gr.AnnotatedImage, label='Image')
make_highlighted_text = functools.partial(gr.HighlightedText, label='Output')


# https://coolors.co/4285f4-db4437-f4b400-0f9d58-e48ef1
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']


@gradio_helpers.synced
def compute(image, prompt, model_name, sampler):
  """Runs model inference."""
  if image is None:
    raise gr.Error('Image required')

  logging.info('prompt="%s"', prompt)

  if isinstance(image, str):
    image = PIL.Image.open(image)
  if gradio_helpers.should_mock():
    logging.warning('Mocking response')
    time.sleep(2.)
    output = paligemma_parse.EXAMPLE_STRING
  else:
    if not model_name:
      raise gr.Error('Models not loaded yet')
    output = models.generate(model_name, sampler, image, prompt)
    logging.info('output="%s"', output)

  width, height = image.size
  objs = paligemma_parse.extract_objs(output, width, height, unique_labels=True)
  labels = set(obj.get('name') for obj in objs if obj.get('name'))
  color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
  highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
  annotated_image = (
      image,
      [
          (
              obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
              obj['name'] or '',
          )
          for obj in objs
          if 'mask' in obj or 'xyxy' in obj
      ],
  )
  has_annotations = bool(annotated_image[1])
  return (
      make_highlighted_text(
          highlighted_text, visible=True, color_map=color_map),
      make_image(image, visible=not has_annotations),
      make_annotated_image(
          annotated_image, visible=has_annotations, width=width, height=height,
          color_map=color_map),
  )


def warmup(model_name):
  image = PIL.Image.new('RGB', [1, 1])
  _ = compute(image, '', model_name, 'greedy')


def reset():
  return (
      '', make_highlighted_text('', visible=False),
      make_image(None, visible=True), make_annotated_image(None, visible=False),
  )


def create_app():
  """Creates demo UI."""

  make_model = lambda choices: gr.Dropdown(
      value=(choices + [''])[0],
      choices=choices,
      label='Model',
      visible=bool(choices),
  )
  make_prompt = lambda value, visible=True: gr.Textbox(
      value, label='Prompt', visible=visible)

  with gr.Blocks() as demo:

    ##### Main UI structure.

    gr.Markdown(INTRO_TEXT)
    with gr.Row():
      image = make_image(None, visible=True)  # input
      annotated_image = make_annotated_image(None, visible=False)  # output
      with gr.Column():
        with gr.Row():
          prompt = make_prompt('', visible=True)
        model_info = gr.Markdown(label='Model Info')
        with gr.Row():
          model = make_model([])
          samplers = [
              'greedy', 'nucleus(0.1)', 'nucleus(0.3)', 'temperature(0.5)']
          sampler = gr.Dropdown(
              value=samplers[0], choices=samplers, label='Decoding'
          )
        with gr.Row():
          run = gr.Button('Run', variant='primary')
          clear = gr.Button('Clear')
        highlighted_text = make_highlighted_text('', visible=False)

    ##### UI logic.

    def update_ui(model, prompt):
      prompt = make_prompt(prompt, visible=True)
      model_info = f'Model `{model}` – {models.MODELS_INFO.get(model, "No info.")}'
      return [prompt, model_info]

    gr.on(
        [model.change],
        update_ui,
        [model, prompt],
        [prompt, model_info],
    )

    gr.on(
        [run.click, prompt.submit],
        compute,
        [image, prompt, model, sampler],
        [highlighted_text, image, annotated_image],
    )
    clear.click(
        reset, None, [prompt, highlighted_text, image, annotated_image]
    )

    ##### Examples.

    gr.set_static_paths(['examples/'])
    all_examples = [json.load(open(p)) for p in glob.glob('examples/*.json')]
    logging.info('loaded %d examples', len(all_examples))
    example_image = gr.Image(
        label='Image', visible=False)  # proxy, never visible
    example_model = gr.Text(
        label='Model', visible=False)  # proxy, never visible
    example_prompt = gr.Text(
        label='Prompt', visible=False)  # proxy, never visible
    example_license = gr.Markdown(
        label='Image License', visible=False)  # placeholder, never visible
    gr.Examples(
        examples=[
            [
                f'examples/{ex["name"]}.jpg',
                ex['prompt'],
                ex['model'],
                ex['license'],
            ]
            for ex in all_examples
            if ex['model'] in models.MODELS
        ],
        inputs=[example_image, example_prompt, example_model, example_license],
    )

    ##### Examples UI logic.

    example_image.change(
        lambda image_path: (
            make_image(image_path, visible=True),
            make_annotated_image(None, visible=False),
            make_highlighted_text('', visible=False),
        ),
        example_image,
        [image, annotated_image, highlighted_text],
    )
    def example_model_changed(model):
      if model not in gradio_helpers.get_paths():
        raise gr.Error(f'Model "{model}" not loaded!')
      return model
    example_model.change(example_model_changed, example_model, model)
    example_prompt.change(make_prompt, example_prompt, prompt)

    ##### Status.

    status = gr.Markdown(f'Startup: {datetime.datetime.now()}')
    gpu_kind = gr.Markdown(f'GPU=?')
    demo.load(
        lambda: [
            gradio_helpers.get_status(),
            make_model(list(gradio_helpers.get_paths())),
        ],
        None,
        [status, model],
    )
    def get_gpu_kind():
      device = jax.devices()[0]
      if not gradio_helpers.should_mock() and device.platform != 'gpu':
        raise gr.Error('GPU not visible to JAX!')
      return f'GPU={device.device_kind}'
    demo.load(get_gpu_kind, None, gpu_kind)

  return demo


if __name__ == '__main__':

  logging.basicConfig(level=logging.INFO,
                      format='%(asctime)s - %(levelname)s - %(message)s')

  logging.info('JAX devices: %s', jax.devices())

  for k, v in os.environ.items():
    logging.info('environ["%s"] = %r', k, v)

  gradio_helpers.set_warmup_function(warmup)
  for name, (repo, filename, revision) in models.MODELS.items():
    gradio_helpers.register_download(name, repo, filename, revision)

  create_app().queue().launch()
